import torch
import os
from DQN_RNN import Critic


class VDN:
    """Value-Decomposition Network (VDN) learner built around a recurrent critic.

    A single ``Critic`` is shared by all agents (agents are folded into the
    batch dimension in ``loss``). The joint action value is the sum of the
    per-agent values. A second ``Critic`` copy is the target network; it is
    hard-synced every ``args.target_update_rate`` optimisation steps.

    Attributes read from ``args``: ``cuda``, ``lr``, ``save_dir``,
    ``max_episode_len``, ``scenario_name``, ``n_agents``, ``obs_shape``,
    ``gamma``, ``grad_clip``, ``target_update_rate``, ``save_rate``,
    ``rnn_layer``, ``hidden_size``.
    """

    def __init__(self, args):
        self.args = args
        self.train_step = 0

        # Online critic and its target copy.
        self.critic_network = Critic(args)
        self.critic_target_network = Critic(args)

        # Directory a checkpoint is loaded from (created if missing).
        self.model_path = self.args.save_dir
        os.makedirs(self.model_path, exist_ok=True)

        # Load the model: check for the same file that is actually loaded.
        # NOTE(review): this reads <save_dir>/critic_params.pkl, while save_model
        # writes <save_dir>/model/<step>_critic_params.pkl -- resuming requires
        # copying/renaming a checkpoint; confirm this is intended.
        critic_path = os.path.join(self.model_path, 'critic_params.pkl')
        if os.path.exists(critic_path):
            # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
            # machine; the networks are moved to the GPU below if requested.
            # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
            self.critic_network.load_state_dict(torch.load(critic_path, map_location='cpu'))
            print('successfully loaded critic_network: {}'.format(critic_path))

        # Sync the target AFTER any checkpoint load so both start from the same weights.
        self.critic_target_network.load_state_dict(self.critic_network.state_dict())

        if self.args.cuda:
            self.critic_network = self.critic_network.cuda()
            self.critic_target_network = self.critic_target_network.cuda()

        self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr, eps=1.5e-05)

    # hard update: copy the online critic's weights into the target critic
    def _update_target_network(self):
        self.critic_target_network.load_state_dict(self.critic_network.state_dict())

    @staticmethod
    def _step_mask(done):
        """Return a 0/1 mask (same shape as *done*) of time steps that contribute to the loss.

        *done* is time-major. Step t is masked out when ``done[t-1]`` is 1.
        The first step is always kept, even if the episode terminates there.
        Assumes ``done`` stays 1 on padded steps after termination -- TODO
        confirm against the replay buffer.
        """
        prev_done = torch.cat([torch.zeros_like(done[:1]), done[:-1]], dim=0)
        return 1 - prev_done

    # update the network
    def loss(self, transitions):
        """Compute the VDN TD loss for a batch of episodes.

        Args:
            transitions: dict of tensors with keys ``"o"``, ``"u"``, ``"a_u"``,
                ``"r"`` and ``"done"``, each with leading dims (batch, episode
                step). ``"o"`` must be (batch, max_episode_len, n_agents,
                *obs dims).

        Returns:
            Scalar tensor: mean, over all (step, batch) entries, of the squared
            TD error. Entries for steps after an episode's termination are
            zeroed.

        Raises:
            NotImplementedError: if ``args.scenario_name`` is not supported.
        """
        obs = transitions["o"]
        batch_size = obs.shape[0]
        episode_length = obs.shape[1]
        assert episode_length == self.args.max_episode_len, (episode_length, self.args.max_episode_len)

        n_agents = self.args.n_agents
        if self.args.scenario_name == 'GuessingNumber':
            expected_shape = (batch_size, episode_length, n_agents, self.args.obs_shape)
        elif self.args.scenario_name == 'RevealingGoal':
            expected_shape = (batch_size,
                              episode_length,
                              n_agents,
                              self.args.obs_shape[0],
                              self.args.obs_shape[1],
                              self.args.obs_shape[2])
        else:
            raise NotImplementedError
        assert obs.shape == expected_shape, obs.shape

        # Switch every input to time-major layout: (episode step, batch, ...).
        o = obs.transpose(0, 1)
        u = transitions["u"].transpose(0, 1)
        a_u = transitions["a_u"].transpose(0, 1)
        r = transitions["r"].transpose(0, 1)
        done = transitions["done"].transpose(0, 1)

        # The critic is shared across agents: fold (step, batch, agent) into one
        # leading dim; each (batch, agent) pair gets its own RNN hidden state.
        n_rows = episode_length * batch_size * n_agents
        rnn_batch = batch_size * n_agents
        o_flat = o.reshape(n_rows, -1)
        a_u_flat = a_u.reshape(n_rows, -1)

        q_value, _ = self.critic_network(o_flat,
                                         a_u_flat,
                                         self.init_hidden(rnn_batch),
                                         batch_size=rnn_batch,
                                         select_id=u.reshape(n_rows, -1).long())
        # VDN: joint value = sum of per-agent values.
        sum_q = q_value.reshape(episode_length, batch_size, -1).sum(dim=-1)

        with torch.no_grad():
            # NOTE(review): without ``select_id`` the critic returns a 3-tuple and
            # its middle element is used as the per-agent bootstrap value --
            # confirm against DQN_RNN.Critic.
            _, q_target, _ = self.critic_target_network(o_flat,
                                                        a_u_flat,
                                                        self.init_hidden(rnn_batch),
                                                        batch_size=rnn_batch)
            q_target = q_target.reshape(episode_length, batch_size, -1)
            # Shift one step forward in time; the value after the last step is zero.
            q_next = torch.cat([q_target[1:], torch.zeros_like(q_target[:1])], dim=0)
            sum_q_next = q_next.sum(dim=-1)

            target_q = r + self.args.gamma * sum_q_next * (1 - done)

        # the q loss
        td_error = (target_q - sum_q) * self._step_mask(done)
        return td_error.pow(2).mean()

    def train(self, loss):
        """Take one optimizer step on *loss* (with gradient-norm clipping).

        Also hard-syncs the target network every ``args.target_update_rate``
        steps and saves a checkpoint every ``args.save_rate`` steps. Step 0
        never triggers either.
        """
        self.critic_optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_network.parameters(), self.args.grad_clip)
        self.critic_optim.step()

        if self.train_step > 0 and self.train_step % self.args.target_update_rate == 0:
            self._update_target_network()

        if self.train_step > 0 and self.train_step % self.args.save_rate == 0:
            self.save_model(self.train_step)
        self.train_step += 1

    def save_model(self, train_step):
        """Write the online critic's state_dict to <save_dir>/model/<train_step>_critic_params.pkl."""
        model_dir = os.path.join(self.args.save_dir, "model")
        os.makedirs(model_dir, exist_ok=True)
        torch.save(self.critic_network.state_dict(),
                   os.path.join(model_dir, str(train_step) + '_critic_params.pkl'))

    def init_hidden(self, batch_size):
        """Return a zero float32 RNN hidden state of shape (rnn_layer, batch_size, hidden_size).

        The tensor is placed on the GPU when ``args.cuda`` is set.
        """
        hidden = torch.zeros(self.args.rnn_layer, batch_size, self.args.hidden_size, dtype=torch.float32)
        return hidden.cuda() if self.args.cuda else hidden


